{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# One Hundred Layers Tiramisu"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from collections import OrderedDict"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Initial Conv Block"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class _FirstConv(nn.Sequential):\n",
    "    def __init__(self, num_input_features):\n",
    "        super(_FirstConv, self).__init__()\n",
    "        self.add_module('conv0', nn.Conv2d(3, num_input_features, kernel_size=7, stride=2, padding=3, bias=False))\n",
    "        self.add_module('norm0', nn.BatchNorm2d(num_input_features))\n",
    "        self.add_module('relu0', nn.ReLU(inplace=True))\n",
    "        self.add_module('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Test\n",
    "conv1 = _FirstConv(5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Dense Layer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class _DenseLayer(nn.Sequential):\n",
    "    def __init__(self, num_input_features, growth_rate, bn_size, drop_rate):\n",
    "        super(_DenseLayer, self).__init__()\n",
    "        self.add_module('norm.1', nn.BatchNorm2d(num_input_features)),\n",
    "        self.add_module('relu.1', nn.ReLU(inplace=True)),\n",
    "        self.add_module('conv.1', nn.Conv2d(num_input_features, bn_size *\n",
    "                        growth_rate, kernel_size=1, stride=1, bias=False)),\n",
    "        self.add_module('norm.2', nn.BatchNorm2d(bn_size * growth_rate)),\n",
    "        self.add_module('relu.2', nn.ReLU(inplace=True)),\n",
    "        self.add_module('conv.2', nn.Conv2d(bn_size * growth_rate, growth_rate,\n",
    "                                            kernel_size=3, stride=1, padding=1, bias=False)),\n",
    "        self.drop_rate = drop_rate\n",
    "\n",
    "    def forward(self, x):\n",
    "        new_features = super(_DenseLayer, self).forward(x)\n",
    "        if self.drop_rate > 0:\n",
    "            new_features = F.dropout(new_features, p=self.drop_rate, training=self.training)\n",
    "        return torch.cat([x, new_features], 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Test\n",
    "dense1 = _DenseLayer(5,3,1,.5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Dense Block"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class _DenseBlock(nn.Sequential):\n",
    "    def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate):\n",
    "        super(_DenseBlock, self).__init__()\n",
    "        for i in range(num_layers):\n",
    "            layer = _DenseLayer(num_input_features + i * growth_rate, growth_rate, bn_size, drop_rate)\n",
    "            self.add_module('denselayer%d' % (i + 1), layer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Test\n",
    "denseBlock1 = _DenseBlock(4,4,1,4,.5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Transition Up"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class _TransitionUp(nn.Sequential):\n",
    "    def __init__(self, num_input_features, num_output_features):\n",
    "        super(_TransitionUp, self).__init__()\n",
    "        self.add_module('norm', nn.BatchNorm2d(num_input_features))\n",
    "        self.add_module('relu', nn.ReLU(inplace=True))\n",
    "        self.add_module('conv', nn.Conv2d(num_input_features, num_output_features,\n",
    "                                          kernel_size=1, stride=1, bias=False))\n",
    "        self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Test\n",
    "transUp = _TransitionUp(5,10)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Transition Down"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class _TransitionDown(nn.Sequential):\n",
    "    def __init__(self, num_input_features, num_output_features):\n",
    "        super(_TransitionDown, self).__init__()\n",
    "        self.add_module('norm', nn.BatchNorm2d(num_input_features))\n",
    "        self.add_module('relu', nn.ReLU(inplace=True))\n",
    "        self.add_module('conv', nn.Conv2d(num_input_features, num_output_features,\n",
    "                                          kernel_size=1, stride=1, bias=False))\n",
    "        self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Test\n",
    "transDown = _TransitionDown(5,10)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Final Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class FCDenseNet(nn.Module):\n",
    "    r\"\"\"FC-DenseNet model class, based on\n",
    "    `\"The One Hundred Layers Tiramisu: Fully Convolutional DenseNets for Semantic Segmentation\" <https://arxiv.org/pdf/1611.09326>`\n",
    "\n",
    "    Args:\n",
    "        growth_rate (int) - how many filters to add each layer (`k` in paper)\n",
    "        block_config (list of 4 ints) - how many layers in each pooling block\n",
    "        num_init_features (int) - the number of filters to learn in the first convolution layer\n",
    "        bn_size (int) - multiplicative factor for number of bottle neck layers\n",
    "          (i.e. bn_size * k features in the bottleneck layer)\n",
    "        drop_rate (float) - dropout rate after each dense layer\n",
    "        num_classes (int) - number of classification classes\n",
    "    \"\"\"\n",
    "    def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16),\n",
    "                 num_init_features=64, bn_size=4, drop_rate=0, num_classes=1000):\n",
    "\n",
    "        super(FCDenseNet, self).__init__()\n",
    "        \n",
    "        self.features = nn.Sequential()\n",
    "        self.features.add_module('firstConv', _FirstConv(num_init_features))\n",
    "\n",
    "        # Each denseblock\n",
    "        num_features = num_init_features\n",
    "        for i, num_layers in enumerate(block_config):\n",
    "            block = _DenseBlock(num_layers=num_layers, num_input_features=num_features,\n",
    "                                bn_size=bn_size, growth_rate=growth_rate, drop_rate=drop_rate)\n",
    "            self.features.add_module('denseblock%d' % (i + 1), block)\n",
    "            num_features = num_features + num_layers * growth_rate\n",
    "            if i != len(block_config) - 1:\n",
    "                trans = _Transition(num_input_features=num_features, num_output_features=num_features // 2)\n",
    "                self.features.add_module('transition%d' % (i + 1), trans)\n",
    "                num_features = num_features // 2\n",
    "\n",
    "        # Final batch norm\n",
    "        self.features.add_module('norm5', nn.BatchNorm2d(num_features))\n",
    "\n",
    "        # Linear layer\n",
    "        self.classifier = nn.Linear(num_features, num_classes)\n",
    "\n",
    "    def forward(self, x):\n",
    "        features = self.features(x)\n",
    "        out = F.relu(features, inplace=True)\n",
    "        out = F.avg_pool2d(out, kernel_size=7).view(features.size(0), -1)\n",
    "        out = self.classifier(out)\n",
    "        return out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Test\n",
    "model = FCDenseNet()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.2"
  },
  "latex_envs": {
   "LaTeX_envs_menu_present": true,
   "autocomplete": true,
   "bibliofile": "biblio.bib",
   "cite_by": "apalike",
   "current_citInitial": 1,
   "eqLabelWithNumbers": true,
   "eqNumInitial": 1,
   "hotkeys": {
    "equation": "Ctrl-E",
    "itemize": "Ctrl-I"
   },
   "labels_anchors": false,
   "latex_user_defs": false,
   "report_style_numbering": false,
   "user_envs_cfg": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
